{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy\n",
    "import nltk\n",
    "import sklearn\n",
    "import random\n",
    "import re\n",
    "from sklearn.feature_extraction.text import CountVectorizer,TfidfTransformer\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from sklearn.multiclass import OneVsRestClassifier\n",
    "from sklearn.metrics import f1_score, precision_score, recall_score\n",
    "from sklearn.linear_model import LogisticRegression, LinearRegression\n",
    "from sklearn.naive_bayes import GaussianNB"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reuters-21578, ModApte version\n",
    "> a collection of 10,788 documents from the Reuters financial newswire service, partitioned into a training set with 7769 documents and a test set with 3019 documents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package reuters to /home/felipe/nltk_data...\n",
      "[nltk_data]   Package reuters is already up-to-date!\n",
      "[nltk_data] Downloading package punkt to /home/felipe/nltk_data...\n",
      "[nltk_data]   Package punkt is already up-to-date!\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Download the NLTK data packages this notebook relies on.\n",
    "nltk.download('reuters')\n",
    "nltk.download('punkt') # needed for tokenization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ZipFilePathPointer(u'/home/felipe/nltk_data/corpora/reuters.zip', u'reuters/')"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Bind the NLTK Reuters corpus reader to a short name and show where it is loaded from.\n",
    "dataset = nltk.corpus.reuters\n",
    "dataset.root"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# dataset.readme()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "90"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of distinct categories (labels) in the corpus.\n",
    "len(dataset.categories())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10788"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Total number of documents (file ids) in the corpus, train and test together.\n",
    "len(dataset.fileids())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'training/3482'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fileids = dataset.fileids()\n",
    "sample_fileid = [ fileids[i] for i in sorted(random.sample(xrange(len(fileids)), 1)) ][0]\n",
    "sample_fileid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ZipFilePathPointer(u'/home/felipe/nltk_data/corpora/reuters.zip', u'reuters/training/3482')"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Path pointer of the sampled document inside the corpus.\n",
    "dataset.abspath(sample_fileid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "101"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of tokens in the sampled document.\n",
    "len(dataset.words(sample_fileid))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[u'U', u'.', u'K', u'.', u'MONEY', u'MARKET', ...]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The sampled document as a flat sequence of tokens.\n",
    "dataset.words(sample_fileid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "u\"U.K. MONEY MARKET SHORTAGE FORECAST AT 300 MLN STG\\n  The Bank of England said it forecast a\\n  liquidity shortage of around 300 mln stg in the market today.\\n      Among the main factors, the Bank said bills maturing in\\n  official hands and the treasury bill take-up would drain 483\\n  mln stg from the system while below target bankers' balances\\n  and a rise in the note circulation would take out 50 mln and\\n  100 mln stg respectively.\\n      Partially offsetting these, exchequer transactions would\\n  add around 355 mln stg, the Bank added.\\n  \\n\\n\""
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The sampled document as a single raw, untokenized string.\n",
    "dataset.raw(sample_fileid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[u'U', u'.', u'K', u'.', u'MONEY', u'MARKET', ...]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Flat token view of the sampled document (same call as an earlier cell).\n",
    "dataset.words(sample_fileid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[u'U', u'.', u'K', u'.', u'MONEY', u'MARKET', u'SHORTAGE', u'FORECAST', u'AT', u'300', u'MLN', u'STG', u'The', u'Bank', u'of', u'England', u'said', u'it', u'forecast', u'a', u'liquidity', u'shortage', u'of', u'around', u'300', u'mln', u'stg', u'in', u'the', u'market', u'today', u'.'], [u'Among', u'the', u'main', u'factors', u',', u'the', u'Bank', u'said', u'bills', u'maturing', u'in', u'official', u'hands', u'and', u'the', u'treasury', u'bill', u'take', u'-', u'up', u'would', u'drain', u'483', u'mln', u'stg', u'from', u'the', u'system', u'while', u'below', u'target', u'bankers', u\"'\", u'balances', u'and', u'a', u'rise', u'in', u'the', u'note', u'circulation', u'would', u'take', u'out', u'50', u'mln', u'and', u'100', u'mln', u'stg', u'respectively', u'.'], ...]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The sampled document as a list of sentences, each sentence being a list of tokens.\n",
    "dataset.sents(sample_fileid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[[u'U', u'.', u'K', u'.', u'MONEY', u'MARKET', u'SHORTAGE', u'FORECAST', u'AT', u'300', u'MLN', u'STG', u'The', u'Bank', u'of', u'England', u'said', u'it', u'forecast', u'a', u'liquidity', u'shortage', u'of', u'around', u'300', u'mln', u'stg', u'in', u'the', u'market', u'today', u'.'], [u'Among', u'the', u'main', u'factors', u',', u'the', u'Bank', u'said', u'bills', u'maturing', u'in', u'official', u'hands', u'and', u'the', u'treasury', u'bill', u'take', u'-', u'up', u'would', u'drain', u'483', u'mln', u'stg', u'from', u'the', u'system', u'while', u'below', u'target', u'bankers', u\"'\", u'balances', u'and', u'a', u'rise', u'in', u'the', u'note', u'circulation', u'would', u'take', u'out', u'50', u'mln', u'and', u'100', u'mln', u'stg', u'respectively', u'.'], [u'Partially', u'offsetting', u'these', u',', u'exchequer', u'transactions', u'would', u'add', u'around', u'355', u'mln', u'stg', u',', u'the', u'Bank', u'added', u'.']]]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The sampled document grouped by paragraph: paragraphs -> sentences -> tokens.\n",
    "dataset.paras(sample_fileid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# http://scikit-learn.org/stable/modules/feature_extraction.html#text-feature-extraction\n",
    "corpus_train = []\n",
    "corpus_test = []\n",
    "for fileid in dataset.fileids():\n",
    "    document = dataset.raw(fileid)\n",
    "    if re.match('training/',fileid):\n",
    "        corpus_train.append(document)\n",
    "    else:\n",
    "        corpus_test.append(document)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(7769, 3019)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: number of training and test documents.\n",
    "len(corpus_train),len(corpus_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def preprocessor(string):\n",
    "    repl = re.sub('&lt;','',string)\n",
    "    return repl.lower()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Bag-of-words vectorizer configuration; it is fitted on the documents in a later cell.\n",
    "# NOTE(review): scikit-learn documents that a custom `preprocessor` overrides the built-in\n",
    "# strip_accents/lowercase stage, so strip_accents='ascii' probably has no effect here -- confirm.\n",
    "vectorizer = CountVectorizer(\n",
    "                min_df=10, # tweaking this parameter reduces the length of the feature vector\n",
    "                strip_accents='ascii',\n",
    "                preprocessor=preprocessor,\n",
    "                stop_words='english')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((7769, 6462), (3019, 6462), (10788, 6462))"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# need to use both corpuses for fitting because otherwise there may be words that only occur in the\n",
    "# training set or in the test set\n",
    "full_corpus = corpus_train + corpus_test\n",
    "vectorizer.fit(full_corpus)\n",
    "\n",
    "X_train_counts = vectorizer.transform(corpus_train)\n",
    "X_test_counts = vectorizer.transform(corpus_test)\n",
    "X_full_counts = vectorizer.transform(full_corpus)\n",
    "\n",
    "X_train_counts.shape,X_test_counts.shape, X_full_counts.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "#uncomment these to see how the vectorizer is analyzing, tokenizing and preprocessing documents\n",
    "\n",
    "#vectorizer.build_analyzer()(dataset.raw(fileid))\n",
    "#vectorizer.build_tokenizer()(\"ADVANCED INSTITUTIONAL &lt;AIMS> CUTS WORKFORCE\\n  Advanced Institutional \")\n",
    "#vectorizer.build_preprocessor()(\"ADVANCED INSTITUTIONAL &lt;AIMS> CUTS WORKFORCE\\n  Advanced Institutional \")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 0, 0, ..., 1, 0, 0])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Dense view of the term counts of the first training document.\n",
    "X_train_counts[0].toarray().ravel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 0, 0, ..., 0, 0, 0])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Dense view of the term counts of the first test document.\n",
    "X_test_counts[0].toarray().ravel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((7769, 6462), (3019, 6462), (10788, 6462))"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transformer = TfidfTransformer()\n",
    "# again, we need to fit the transformer to all documents (train and test)\n",
    "transformer.fit(X_full_counts)\n",
    "\n",
    "X_train_tfidf = transformer.transform(X_train_counts)\n",
    "X_test_tfidf = transformer.transform(X_test_counts)\n",
    "X_full_tfidf = transformer.transform(X_full_counts)\n",
    "\n",
    "X_train_tfidf.shape, X_test_tfidf.shape, X_full_tfidf.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "collapsed": false,
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.       ,  0.       ,  0.       , ...,  0.0466051,  0.       ,  0.       ])"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Dense view of the tf-idf weights of the first training document.\n",
    "X_train_tfidf[0].toarray().ravel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.,  0.,  0., ...,  0.,  0.,  0.])"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Dense view of the tf-idf weights of the first test document.\n",
    "X_test_tfidf[0].toarray().ravel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((7769, 90), (3019, 90))"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y_train = []\n",
    "Y_test = []\n",
    "\n",
    "for (idx,fileid) in enumerate(dataset.fileids()):    \n",
    "    categories = '*'.join(dataset.categories(fileid))\n",
    "\n",
    "    if re.match('training/',fileid):\n",
    "        Y_train.append(categories)\n",
    "    else:\n",
    "        Y_test.append(categories)\n",
    "\n",
    "series_train = pd.Series(Y_train)\n",
    "Y_train_df = series_train.str.get_dummies(sep='*')\n",
    "\n",
    "series_test = pd.Series(Y_test)\n",
    "Y_test_df = series_test.str.get_dummies(sep='*')\n",
    "\n",
    "Y_train = Y_train_df.values\n",
    "Y_test = Y_test_df.values\n",
    "\n",
    "Y_train.shape,Y_test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 6.15 s, sys: 3.64 ms, total: 6.16 s\n",
      "Wall time: 6.18 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "# One-vs-rest wrapper: a separate logistic regression is fitted for each label column.\n",
    "clf = LogisticRegression()\n",
    "meta_clf = OneVsRestClassifier(clf)\n",
    "\n",
    "meta_clf.fit(X_train_tfidf,Y_train)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Predict the label indicator matrix for the test documents.\n",
    "Y_pred = meta_clf.predict(X_test_tfidf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.76201298701298703"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Micro-averaged F1 score of the predictions against the test labels.\n",
    "f1_score(Y_test,Y_pred,average='micro')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
